#include "statsxx/machine_learning/DBN.hpp"

// STL
#include <cstddef> // std::size_t
#include <vector>  // std::vector<>

// jScience
#include "jScience/linalg.hpp" // Vector<>


// using namespace machine_learning;


//
// DESC: Forward propagate a visible vector through the entire DBN, getting all of the hidden units.
//
// NOTE: 
//
inline std::vector<Vector<double>> machine_learning::DBN::v_to_h(
                                                                 const Vector<double> &v
                                                                 ) const
{
    std::vector<Vector<double>> h(this->RBM.size());
    
    h[0] = this->RBM[0].v_to_h(
                               v
                               );
    
    for(auto i = 1; i < this->RBM.size(); ++i)
    {
        h[i] = this->RBM[i].v_to_h(
                                   h[i-1]
                                   );
    }

    return h;
}


//
// DESC: Wrapper to the above.
//
inline std::vector<Matrix<double>> machine_learning::DBN::v_to_h(
                                                                 const Matrix<double> &V
                                                                 ) const
{
    std::vector<Matrix<double>> H;
    
    for(auto i = 0; i < this->RBM.size(); ++i)
    {
        H.push_back(Matrix<double>(V.size(0),this->RBM[i].get_nh()));
    }
    
    for(auto i = 0; i < V.size(0); ++i)
    {
        std::vector<Vector<double>> h = v_to_h(
                                               V.row(i)
                                               );

        for(auto k = 0; k < h.size(); ++k)
        {
            for(auto j = 0; j < h[k].size(); ++j)
            {
                H[k](i,j) = h[k](j);
            }
        }
    }

    return H;
}